"tests/vscode:/vscode.git/clone" did not exist on "38a1f5821acb516e6036df45deaa39185b88de6e"
Commit 71660f1c authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

refine prediction logic. (#395)

* refine prediction logic.

* fix test.

* fix out_len in training score of Dart.

* improve predict speed for high dimension data.
parent f1ffc10d
...@@ -110,27 +110,29 @@ public: ...@@ -110,27 +110,29 @@ public:
*/ */
virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0; virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;
virtual int NumPredictOneRow(int num_iteration, int is_pred_leaf) const = 0;
/*! /*!
* \brief Prediction for one record, not sigmoid transform * \brief Prediction for one record, not sigmoid transform
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \param output Prediction result for this record
*/ */
virtual std::vector<double> PredictRaw(const double* feature_values) const = 0; virtual void PredictRaw(const double* feature_values, double* output) const = 0;
/*! /*!
* \brief Prediction for one record, sigmoid transformation will be used if needed * \brief Prediction for one record, sigmoid transformation will be used if needed
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Prediction result for this record * \param output Prediction result for this record
*/ */
virtual std::vector<double> Predict(const double* feature_values) const = 0; virtual void Predict(const double* feature_values, double* output) const = 0;
/*! /*!
* \brief Prediction for one record with leaf index * \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \return Predicted leaf index for this record * \param output Prediction result for this record
*/ */
virtual std::vector<int> PredictLeafIndex( virtual void PredictLeafIndex(
const double* feature_values) const = 0; const double* feature_values, double* output) const = 0;
/*! /*!
* \brief Dump model to json format string * \brief Dump model to json format string
...@@ -185,6 +187,12 @@ public: ...@@ -185,6 +187,12 @@ public:
*/ */
virtual int NumberOfTotalModel() const = 0; virtual int NumberOfTotalModel() const = 0;
/*!
* \brief Get number of trees per iteration
* \return Number of trees per iteration
*/
virtual int NumTreePerIteration() const = 0;
/*! /*!
* \brief Get number of classes * \brief Get number of classes
* \return Number of classes * \return Number of classes
...@@ -192,9 +200,11 @@ public: ...@@ -192,9 +200,11 @@ public:
virtual int NumberOfClasses() const = 0; virtual int NumberOfClasses() const = 0;
/*! /*!
* \brief Set number of used model for prediction * \brief Initial work for the prediction
* \param num_iteration number of used iteration
* \return the feature indices mapper
*/ */
virtual void SetNumIterationForPred(int num_iteration) = 0; virtual std::vector<int> InitPredict(int num_iteration) = 0;
/*! /*!
* \brief Name of submodel * \brief Name of submodel
......
...@@ -22,7 +22,7 @@ const score_t kEpsilon = 1e-15f; ...@@ -22,7 +22,7 @@ const score_t kEpsilon = 1e-15f;
using ReduceFunction = std::function<void(const char*, char*, int)>; using ReduceFunction = std::function<void(const char*, char*, int)>;
using PredictFunction = using PredictFunction =
std::function<std::vector<double>(const std::vector<std::pair<int, double>>&)>; std::function<void(const std::vector<std::pair<int, double>>&, double* output)>;
#define NO_SPECIFIC (-1) #define NO_SPECIFIC (-1)
......
...@@ -34,8 +34,7 @@ public: ...@@ -34,8 +34,7 @@ public:
* \brief Calcaluting and printing metric result * \brief Calcaluting and printing metric result
* \param score Current prediction score * \param score Current prediction score
*/ */
virtual std::vector<double> Eval(const double* score, const ObjectiveFunction* objective, virtual std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const = 0;
int num_tree_per_iteration) const = 0;
Metric() = default; Metric() = default;
/*! \brief Disable copy */ /*! \brief Disable copy */
......
...@@ -39,14 +39,12 @@ public: ...@@ -39,14 +39,12 @@ public:
virtual bool SkipEmptyClass() const { return false; } virtual bool SkipEmptyClass() const { return false; }
virtual int numTreePerIteration() const { return 1; } virtual int NumTreePerIteration() const { return 1; }
virtual std::vector<double> ConvertOutput(std::vector<double>& input) const { virtual int NumPredictOneRow() const { return 1; }
return input;
}
virtual double ConvertOutput(double input) const { virtual void ConvertOutput(const double* input, double* output) const {
return input; output[0] = input[0];
} }
virtual std::string ToString() const = 0; virtual std::string ToString() const = 0;
......
...@@ -111,6 +111,13 @@ public: ...@@ -111,6 +111,13 @@ public:
shrinkage_ *= rate; shrinkage_ *= rate;
} }
inline void ReMapFeature(const std::vector<int>& feature_mapper) {
mapped_feature_ = split_feature_;
for (int i = 0; i < num_leaves_ - 1; ++i) {
mapped_feature_[i] = feature_mapper[split_feature_[i]];
}
}
/*! \brief Serialize this object to string*/ /*! \brief Serialize this object to string*/
std::string ToString(); std::string ToString();
...@@ -194,9 +201,10 @@ private: ...@@ -194,9 +201,10 @@ private:
std::vector<int> leaf_depth_; std::vector<int> leaf_depth_;
double shrinkage_; double shrinkage_;
bool has_categorical_; bool has_categorical_;
/*! \brief buffer of mapped split_feature_ */
std::vector<int> mapped_feature_;
}; };
inline double Tree::Predict(const double* feature_values) const { inline double Tree::Predict(const double* feature_values) const {
if (num_leaves_ > 1) { if (num_leaves_ > 1) {
int leaf = GetLeaf(feature_values); int leaf = GetLeaf(feature_values);
...@@ -217,15 +225,27 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const { ...@@ -217,15 +225,27 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const {
inline int Tree::GetLeaf(const double* feature_values) const { inline int Tree::GetLeaf(const double* feature_values) const {
int node = 0; int node = 0;
if (has_categorical_) {
while (node >= 0) { while (node >= 0) {
if (decision_funs[decision_type_[node]]( if (decision_funs[decision_type_[node]](
feature_values[split_feature_[node]], feature_values[mapped_feature_[node]],
threshold_[node])) { threshold_[node])) {
node = left_child_[node]; node = left_child_[node];
} else { } else {
node = right_child_[node]; node = right_child_[node];
} }
} }
} else {
while (node >= 0) {
if (NumericalDecision<double>(
feature_values[mapped_feature_[node]],
threshold_[node])) {
node = left_child_[node];
} else {
node = right_child_[node];
}
}
}
return ~node; return ~node;
} }
......
...@@ -377,18 +377,18 @@ inline void Softmax(std::vector<double>* p_rec) { ...@@ -377,18 +377,18 @@ inline void Softmax(std::vector<double>* p_rec) {
} }
} }
inline void Softmax(double* rec, int len) { inline void Softmax(const double* input, double* output, int len) {
double wmax = rec[0]; double wmax = input[0];
for (int i = 1; i < len; ++i) { for (int i = 1; i < len; ++i) {
wmax = std::max(rec[i], wmax); wmax = std::max(input[i], wmax);
} }
double wsum = 0.0f; double wsum = 0.0f;
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
rec[i] = std::exp(rec[i] - wmax); output[i] = std::exp(input[i] - wmax);
wsum += rec[i]; wsum += output[i];
} }
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
rec[i] /= static_cast<double>(wsum); output[i] /= static_cast<double>(wsum);
} }
} }
......
...@@ -54,8 +54,7 @@ void Application::LoadParameters(int argc, char** argv) { ...@@ -54,8 +54,7 @@ void Application::LoadParameters(int argc, char** argv) {
continue; continue;
} }
params[key] = value; params[key] = value;
} } else {
else {
Log::Warning("Unknown parameter in command line: %s", argv[i]); Log::Warning("Unknown parameter in command line: %s", argv[i]);
} }
} }
...@@ -86,8 +85,7 @@ void Application::LoadParameters(int argc, char** argv) { ...@@ -86,8 +85,7 @@ void Application::LoadParameters(int argc, char** argv) {
if (params.count(key) == 0) { if (params.count(key) == 0) {
params[key] = value; params[key] = value;
} }
} } else {
else {
Log::Warning("Unknown parameter in config file: %s", line.c_str()); Log::Warning("Unknown parameter in config file: %s", line.c_str());
} }
} }
...@@ -110,7 +108,7 @@ void Application::LoadData() { ...@@ -110,7 +108,7 @@ void Application::LoadData() {
PredictFunction predict_fun = nullptr; PredictFunction predict_fun = nullptr;
// need to continue training // need to continue training
if (boosting_->NumberOfTotalModel() > 0) { if (boosting_->NumberOfTotalModel() > 0) {
predictor.reset(new Predictor(boosting_.get(), true, false)); predictor.reset(new Predictor(boosting_.get(), -1, true, false));
predict_fun = predictor->GetPredictFunction(); predict_fun = predictor->GetPredictFunction();
} }
...@@ -121,7 +119,7 @@ void Application::LoadData() { ...@@ -121,7 +119,7 @@ void Application::LoadData() {
} }
DatasetLoader dataset_loader(config_.io_config, predict_fun, DatasetLoader dataset_loader(config_.io_config, predict_fun,
boosting_->NumberOfClasses(), config_.io_config.data_filename.c_str()); config_.boosting_config.num_class, config_.io_config.data_filename.c_str());
// load Training data // load Training data
if (config_.is_parallel_find_bin) { if (config_.is_parallel_find_bin) {
// load data for parallel training // load data for parallel training
...@@ -241,9 +239,8 @@ void Application::Train() { ...@@ -241,9 +239,8 @@ void Application::Train() {
void Application::Predict() { void Application::Predict() {
boosting_->SetNumIterationForPred(config_.io_config.num_iteration_predict);
// create predictor // create predictor
Predictor predictor(boosting_.get(), config_.io_config.is_predict_raw_score, Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
config_.io_config.is_predict_leaf_index); config_.io_config.is_predict_leaf_index);
predictor.Predict(config_.io_config.data_filename.c_str(), predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header); config_.io_config.output_result.c_str(), config_.io_config.has_header);
......
...@@ -26,39 +26,42 @@ public: ...@@ -26,39 +26,42 @@ public:
/*! /*!
* \brief Constructor * \brief Constructor
* \param boosting Input boosting model * \param boosting Input boosting model
* \param num_iteration Number of boosting round
* \param is_raw_score True if need to predict result with raw score * \param is_raw_score True if need to predict result with raw score
* \param is_predict_leaf_index True if output leaf index instead of prediction score * \param is_predict_leaf_index True if output leaf index instead of prediction score
*/ */
Predictor(const Boosting* boosting, bool is_raw_score, bool is_predict_leaf_index) { Predictor(Boosting* boosting, int num_iteration,
bool is_raw_score, bool is_predict_leaf_index) {
feature_mapper_ = boosting->InitPredict(num_iteration);
boosting_ = boosting; boosting_ = boosting;
num_features_ = boosting_->MaxFeatureIdx() + 1; num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index);
#pragma omp parallel
#pragma omp master num_total_features_ = static_cast<int>(feature_mapper_.size());
{ num_used_features_ = 1;
num_threads_ = omp_get_num_threads(); for (auto fidx : feature_mapper_) {
} num_used_features_ = std::max(num_used_features_, fidx + 1);
for (int i = 0; i < num_threads_; ++i) {
features_.push_back(std::vector<double>(num_features_));
} }
features_.shrink_to_fit();
features_ = std::vector<double>(num_used_features_);
if (is_predict_leaf_index) { if (is_predict_leaf_index) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) { predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
const int tid = PutFeatureValuesToBuffer(features); PutFeatureValuesToBuffer(features);
// get result for leaf index // get result for leaf index
auto result = boosting_->PredictLeafIndex(features_[tid].data()); boosting_->PredictLeafIndex(features_.data(), output);
return std::vector<double>(result.begin(), result.end());
}; };
} else { } else {
if (is_raw_score) { if (is_raw_score) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) { predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
const int tid = PutFeatureValuesToBuffer(features); PutFeatureValuesToBuffer(features);
// get result without sigmoid transformation // get result without sigmoid transformation
return boosting_->PredictRaw(features_[tid].data()); boosting_->PredictRaw(features_.data(), output);
}; };
} else { } else {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features) { predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
const int tid = PutFeatureValuesToBuffer(features); PutFeatureValuesToBuffer(features);
return boosting_->Predict(features_[tid].data()); boosting_->Predict(features_.data(), output);
}; };
} }
} }
...@@ -81,16 +84,16 @@ public: ...@@ -81,16 +84,16 @@ public:
void Predict(const char* data_filename, const char* result_filename, bool has_header) { void Predict(const char* data_filename, const char* result_filename, bool has_header) {
FILE* result_file; FILE* result_file;
#ifdef _MSC_VER #ifdef _MSC_VER
fopen_s(&result_file, result_filename, "w"); fopen_s(&result_file, result_filename, "w");
#else #else
result_file = fopen(result_filename, "w"); result_file = fopen(result_filename, "w");
#endif #endif
if (result_file == NULL) { if (result_file == NULL) {
Log::Fatal("Prediction results file %s doesn't exist", data_filename); Log::Fatal("Prediction results file %s doesn't exist", data_filename);
} }
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, num_features_, boosting_->LabelIdx())); auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, num_used_features_, boosting_->LabelIdx()));
if (parser == nullptr) { if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename); Log::Fatal("Could not recognize the data format of data file %s", data_filename);
...@@ -108,52 +111,47 @@ public: ...@@ -108,52 +111,47 @@ public:
[this, &parser_fun, &result_file] [this, &parser_fun, &result_file]
(data_size_t, const std::vector<std::string>& lines) { (data_size_t, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, double>> oneline_features;
std::vector<std::string> pred_result(lines.size(), "");
OMP_INIT_EX();
#pragma omp parallel for schedule(static) private(oneline_features)
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) { for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
OMP_LOOP_EX_BEGIN();
oneline_features.clear(); oneline_features.clear();
// parser // parser
parser_fun(lines[i].c_str(), &oneline_features); parser_fun(lines[i].c_str(), &oneline_features);
// predict // predict
pred_result[i] = Common::Join<double>(predict_fun_(oneline_features), "\t"); std::vector<double> result(num_pred_one_row_);
OMP_LOOP_EX_END(); predict_fun_(oneline_features, result.data());
} auto str_result = Common::Join<double>(result, "\t");
OMP_THROW_EX(); fprintf(result_file, "%s\n", str_result.c_str());
for (size_t i = 0; i < pred_result.size(); ++i) {
fprintf(result_file, "%s\n", pred_result[i].c_str());
} }
}; };
TextReader<data_size_t> predict_data_reader(data_filename, has_header); TextReader<data_size_t> predict_data_reader(data_filename, has_header);
predict_data_reader.ReadAllAndProcessParallel(process_fun); predict_data_reader.ReadAllAndProcessParallel(process_fun);
fclose(result_file); fclose(result_file);
} }
private: private:
int PutFeatureValuesToBuffer(const std::vector<std::pair<int, double>>& features) { void PutFeatureValuesToBuffer(const std::vector<std::pair<int, double>>& features) {
int tid = omp_get_thread_num(); std::memset(features_.data(), 0, sizeof(double)*num_used_features_);
// init feature value
std::memset(features_[tid].data(), 0, sizeof(double)*num_features_);
// put feature value // put feature value
for (const auto& p : features) { int loop_size = static_cast<int>(features.size());
if (p.first < num_features_) { #pragma omp parallel for schedule(static, 512) if(loop_size >= 1024)
features_[tid][p.first] = p.second; for (int i = 0; i < loop_size; ++i) {
if (features[i].first >= num_total_features_) continue;
auto fidx = feature_mapper_[features[i].first];
if (fidx >= 0) {
features_[fidx] = features[i].second;
} }
} }
return tid;
} }
/*! \brief Boosting model */ /*! \brief Boosting model */
const Boosting* boosting_; const Boosting* boosting_;
/*! \brief Buffer for feature values */ /*! \brief Buffer for feature values */
std::vector<std::vector<double>> features_; std::vector<double> features_;
/*! \brief Number of features */ /*! \brief Number of features */
int num_features_; int num_used_features_;
/*! \brief Number of threads */
int num_threads_;
/*! \brief function for prediction */ /*! \brief function for prediction */
PredictFunction predict_fun_; PredictFunction predict_fun_;
int num_pred_one_row_;
std::vector<int> feature_mapper_;
int num_total_features_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -85,7 +85,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_ ...@@ -85,7 +85,7 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
num_tree_per_iteration_ = num_class_; num_tree_per_iteration_ = num_class_;
if (objective_function_ != nullptr) { if (objective_function_ != nullptr) {
is_constant_hessian_ = objective_function_->IsConstantHessian(); is_constant_hessian_ = objective_function_->IsConstantHessian();
num_tree_per_iteration_ = objective_function_->numTreePerIteration(); num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
} else { } else {
is_constant_hessian_ = false; is_constant_hessian_ = false;
} }
...@@ -525,7 +525,7 @@ std::string GBDT::OutputMetric(int iter) { ...@@ -525,7 +525,7 @@ std::string GBDT::OutputMetric(int iter) {
if (need_output) { if (need_output) {
for (auto& sub_metric : training_metrics_) { for (auto& sub_metric : training_metrics_) {
auto name = sub_metric->GetName(); auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_, num_tree_per_iteration_); auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
for (size_t k = 0; k < name.size(); ++k) { for (size_t k = 0; k < name.size(); ++k) {
std::stringstream tmp_buf; std::stringstream tmp_buf;
tmp_buf << "Iteration:" << iter tmp_buf << "Iteration:" << iter
...@@ -543,8 +543,7 @@ std::string GBDT::OutputMetric(int iter) { ...@@ -543,8 +543,7 @@ std::string GBDT::OutputMetric(int iter) {
for (size_t i = 0; i < valid_metrics_.size(); ++i) { for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) { for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score(), auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score(),
objective_function_, objective_function_);
num_tree_per_iteration_);
auto name = valid_metrics_[i][j]->GetName(); auto name = valid_metrics_[i][j]->GetName();
for (size_t k = 0; k < name.size(); ++k) { for (size_t k = 0; k < name.size(); ++k) {
std::stringstream tmp_buf; std::stringstream tmp_buf;
...@@ -583,8 +582,7 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const { ...@@ -583,8 +582,7 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const {
std::vector<double> ret; std::vector<double> ret;
if (data_idx == 0) { if (data_idx == 0) {
for (auto& sub_metric : training_metrics_) { for (auto& sub_metric : training_metrics_) {
auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_, auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
num_tree_per_iteration_);
for (auto score : scores) { for (auto score : scores) {
ret.push_back(score); ret.push_back(score);
} }
...@@ -593,8 +591,7 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const { ...@@ -593,8 +591,7 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const {
auto used_idx = data_idx - 1; auto used_idx = data_idx - 1;
for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) { for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score(), auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score(),
objective_function_, objective_function_);
num_tree_per_iteration_);
for (auto score : test_scores) { for (auto score : test_scores) {
ret.push_back(score); ret.push_back(score);
} }
...@@ -626,11 +623,12 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { ...@@ -626,11 +623,12 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
if (objective_function_ != nullptr) { if (objective_function_ != nullptr) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
std::vector<double> tmp_result(num_class_); std::vector<double> tree_pred(num_tree_per_iteration_);
for (int j = 0; j < num_tree_per_iteration_; ++j) { for (int j = 0; j < num_tree_per_iteration_; ++j) {
tmp_result[j] = raw_scores[j * num_data + i]; tree_pred[j] = raw_scores[j * num_data + i];
} }
tmp_result = objective_function_->ConvertOutput(tmp_result); std::vector<double> tmp_result(num_class_);
objective_function_->ConvertOutput(tree_pred.data(), tmp_result.data());
for (int j = 0; j < num_class_; ++j) { for (int j = 0; j < num_class_; ++j) {
out_result[j * num_data + i] = static_cast<double>(tmp_result[j]); out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
} }
...@@ -638,12 +636,9 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { ...@@ -638,12 +636,9 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
} else { } else {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
std::vector<double> tmp_result(num_class_); std::vector<double> tmp_result(num_tree_per_iteration_);
for (int j = 0; j < num_tree_per_iteration_; ++j) { for (int j = 0; j < num_tree_per_iteration_; ++j) {
tmp_result[j] = raw_scores[j * num_data + i]; out_result[j * num_data + i] = static_cast<double>(raw_scores[j * num_data + i]);
}
for (int j = 0; j < num_class_; ++j) {
out_result[j * num_data + i] = static_cast<double>(tmp_result[j]);
} }
} }
} }
...@@ -875,38 +870,57 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const { ...@@ -875,38 +870,57 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
return pairs; return pairs;
} }
std::vector<double> GBDT::PredictRaw(const double* value) const {
std::vector<double> ret(num_tree_per_iteration_, 0.0f);
void GBDT::PredictRaw(const double* value, double* output) const {
if (num_threads_ <= num_tree_per_iteration_) {
#pragma omp parallel for schedule(static)
for (int k = 0; k < num_tree_per_iteration_; ++k) {
for (int i = 0; i < num_iteration_for_pred_; ++i) { for (int i = 0; i < num_iteration_for_pred_; ++i) {
for (int j = 0; j < num_tree_per_iteration_; ++j) { output[k] += models_[i * num_tree_per_iteration_ + k]->Predict(value);
ret[j] += models_[i * num_tree_per_iteration_ + j]->Predict(value); }
}
} else {
for (int k = 0; k < num_tree_per_iteration_; ++k) {
double t = 0.0f;
#pragma omp parallel for schedule(static) reduction(+:t)
for (int i = 0; i < num_iteration_for_pred_; ++i) {
t += models_[i * num_tree_per_iteration_ + k]->Predict(value);
}
output[k] = t;
} }
} }
return ret;
} }
std::vector<double> GBDT::Predict(const double* value) const { void GBDT::Predict(const double* value, double* output) const {
std::vector<double> ret(num_tree_per_iteration_, 0.0f); if (num_threads_ <= num_tree_per_iteration_) {
#pragma omp parallel for schedule(static)
for (int k = 0; k < num_tree_per_iteration_; ++k) {
for (int i = 0; i < num_iteration_for_pred_; ++i) { for (int i = 0; i < num_iteration_for_pred_; ++i) {
for (int j = 0; j < num_tree_per_iteration_; ++j) { output[k] += models_[i * num_tree_per_iteration_ + k]->Predict(value);
ret[j] += models_[i * num_tree_per_iteration_ + j]->Predict(value);
} }
} }
if (objective_function_ != nullptr) {
return objective_function_->ConvertOutput(ret);
} else { } else {
return ret; for (int k = 0; k < num_tree_per_iteration_; ++k) {
double t = 0.0f;
#pragma omp parallel for schedule(static) reduction(+:t)
for (int i = 0; i < num_iteration_for_pred_; ++i) {
t += models_[i * num_tree_per_iteration_ + k]->Predict(value);
}
output[k] = t;
}
}
if (objective_function_ != nullptr) {
objective_function_->ConvertOutput(output, output);
} }
} }
std::vector<int> GBDT::PredictLeafIndex(const double* value) const { void GBDT::PredictLeafIndex(const double* value, double* output) const {
std::vector<int> ret; int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
for (int i = 0; i < num_iteration_for_pred_; ++i) { #pragma omp parallel for schedule(static)
for (int j = 0; j < num_tree_per_iteration_; ++j) { for (int i = 0; i < total_tree; ++i) {
ret.push_back(models_[i * num_tree_per_iteration_ + j]->PredictLeafIndex(value)); output[i] = models_[i]->PredictLeafIndex(value);
}
} }
return ret;
} }
} // namespace LightGBM } // namespace LightGBM
...@@ -50,13 +50,13 @@ public: ...@@ -50,13 +50,13 @@ public:
auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get()))); auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get())));
models_.push_back(std::move(new_tree)); models_.push_back(std::move(new_tree));
} }
num_init_iteration_ = static_cast<int>(models_.size()) / num_class_; num_init_iteration_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
// push model in current object // push model in current object
for (const auto& tree : original_models) { for (const auto& tree : original_models) {
auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get()))); auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get())));
models_.push_back(std::move(new_tree)); models_.push_back(std::move(new_tree));
} }
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_; num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
} }
/*! /*!
...@@ -88,7 +88,7 @@ public: ...@@ -88,7 +88,7 @@ public:
*/ */
void RollbackOneIter() override; void RollbackOneIter() override;
int GetCurrentIteration() const override { return static_cast<int>(models_.size()) / num_class_; } int GetCurrentIteration() const override { return static_cast<int>(models_.size()) / num_tree_per_iteration_; }
bool EvalAndCheckEarlyStopping() override; bool EvalAndCheckEarlyStopping() override;
...@@ -122,26 +122,24 @@ public: ...@@ -122,26 +122,24 @@ public:
*/ */
void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) override; void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) override;
/*! inline int NumPredictOneRow(int num_iteration, int is_pred_leaf) const override {
* \brief Prediction for one record without sigmoid transformation int num_preb_in_one_row = num_class_;
* \param feature_values Feature value on this record if (is_pred_leaf) {
* \return Prediction result for this record int max_iteration = GetCurrentIteration();
*/ if (num_iteration > 0) {
std::vector<double> PredictRaw(const double* feature_values) const override; num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
} else {
num_preb_in_one_row *= max_iteration;
}
}
return num_preb_in_one_row;
}
/*! void PredictRaw(const double* feature_values, double* output) const override;
* \brief Prediction for one record with sigmoid transformation if enabled
* \param feature_values Feature value on this record
* \return Prediction result for this record
*/
std::vector<double> Predict(const double* feature_values) const override;
/*! void Predict(const double* feature_values, double* output) const override;
* \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record void PredictLeafIndex(const double* value, double* output) const override;
* \return Predicted leaf index for this record
*/
std::vector<int> PredictLeafIndex(const double* value) const override;
/*! /*!
* \brief Dump model to json format string * \brief Dump model to json format string
...@@ -193,20 +191,51 @@ public: ...@@ -193,20 +191,51 @@ public:
*/ */
inline int NumberOfTotalModel() const override { return static_cast<int>(models_.size()); } inline int NumberOfTotalModel() const override { return static_cast<int>(models_.size()); }
/*!
* \brief Get number of tree per iteration
* \return number of tree per iteration
*/
inline int NumTreePerIteration() const override { return num_tree_per_iteration_; }
/*! /*!
* \brief Get number of classes * \brief Get number of classes
* \return Number of classes * \return Number of classes
*/ */
inline int NumberOfClasses() const override { return num_class_; } inline int NumberOfClasses() const override { return num_class_; }
/*! inline std::vector<int> InitPredict(int num_iteration) override {
* \brief Set number of iterations for prediction num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
*/
inline void SetNumIterationForPred(int num_iteration) override {
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
if (num_iteration > 0) { if (num_iteration > 0) {
num_iteration_for_pred_ = std::min(num_iteration + (boost_from_average_ ? 1 : 0), num_iteration_for_pred_); num_iteration_for_pred_ = std::min(num_iteration + (boost_from_average_ ? 1 : 0), num_iteration_for_pred_);
} }
int used_fidx = 0;
// Construct used feature mapper
std::vector<int> feature_mapper(max_feature_idx_ + 1, -1);
int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
#pragma omp parallel for schedule(static, 64) if (total_tree >= 128)
for (int i = 0; i < total_tree; ++i) {
int num_leaves = models_[i]->num_leaves();
for (int j = 0; j < num_leaves - 1; ++j) {
int fidx = models_[i]->split_feature(j);
if (feature_mapper[fidx] == -1) {
#pragma omp critical
{
if (feature_mapper[fidx] == -1) {
feature_mapper[fidx] = used_fidx;
++used_fidx;
}
}
}
}
}
#pragma omp parallel for schedule(static, 64) if (total_tree >= 128)
for (int i = 0; i < total_tree; ++i) {
models_[i]->ReMapFeature(feature_mapper);
}
return feature_mapper;
} }
inline double GetLeafValue(int tree_idx, int leaf_idx) const { inline double GetLeafValue(int tree_idx, int leaf_idx) const {
......
...@@ -160,9 +160,10 @@ public: ...@@ -160,9 +160,10 @@ public:
boosting_->RollbackOneIter(); boosting_->RollbackOneIter();
} }
Predictor NewPredictor(int num_iteration, int predict_type) { void Predict(int num_iteration, int predict_type, int nrow,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
double* out_result, int64_t* out_len) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
boosting_->SetNumIterationForPred(num_iteration);
bool is_predict_leaf = false; bool is_predict_leaf = false;
bool is_raw_score = false; bool is_raw_score = false;
if (predict_type == C_API_PREDICT_LEAF_INDEX) { if (predict_type == C_API_PREDICT_LEAF_INDEX) {
...@@ -172,9 +173,33 @@ public: ...@@ -172,9 +173,33 @@ public:
} else { } else {
is_raw_score = false; is_raw_score = false;
} }
// not threading safe now Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf);
// boosting_->SetNumIterationForPred may be set by other thread during prediction. int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf);
return Predictor(boosting_.get(), is_raw_score, is_predict_leaf); auto pred_fun = predictor.GetPredictFunction();
auto pred_wrt_ptr = out_result;
for (int i = 0; i < nrow; ++i) {
auto one_row = get_row_fun(i);
pred_fun(one_row, pred_wrt_ptr);
pred_wrt_ptr += num_preb_in_one_row;
}
*out_len = nrow * num_preb_in_one_row;
}
void Predict(int num_iteration, int predict_type, const char* data_filename,
int data_has_header, const char* result_filename) {
std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false;
bool is_raw_score = false;
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
is_predict_leaf = true;
} else if (predict_type == C_API_PREDICT_RAW_SCORE) {
is_raw_score = true;
} else {
is_raw_score = false;
}
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf);
bool bool_data_has_header = data_has_header > 0 ? true : false;
predictor.Predict(data_filename, result_filename, bool_data_has_header);
} }
void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
...@@ -288,11 +313,11 @@ private: ...@@ -288,11 +313,11 @@ private:
// start of c_api functions // start of c_api functions
LIGHTGBM_C_EXPORT const char* LGBM_GetLastError() { const char* LGBM_GetLastError() {
return LastErrorMsg(); return LastErrorMsg();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, int LGBM_DatasetCreateFromFile(const char* filename,
const char* parameters, const char* parameters,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
...@@ -311,7 +336,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -311,7 +336,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledColumn(double** sample_data, int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
int** sample_indices, int** sample_indices,
int32_t ncol, int32_t ncol,
const int* num_per_col, const int* num_per_col,
...@@ -331,7 +356,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledColumn(double** sample_data, ...@@ -331,7 +356,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference, int LGBM_DatasetCreateByReference(const DatasetHandle reference,
int64_t num_total_row, int64_t num_total_row,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
...@@ -342,7 +367,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc ...@@ -342,7 +367,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset, int LGBM_DatasetPushRows(DatasetHandle dataset,
const void* data, const void* data,
int data_type, int data_type,
int32_t nrow, int32_t nrow,
...@@ -352,7 +377,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset, ...@@ -352,7 +377,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
auto p_dataset = reinterpret_cast<Dataset*>(dataset); auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1); auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) { for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
...@@ -367,7 +392,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset, ...@@ -367,7 +392,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRows(DatasetHandle dataset,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
const void* indptr, const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
...@@ -382,7 +407,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, ...@@ -382,7 +407,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int32_t nrow = static_cast<int32_t>(nindptr - 1); int32_t nrow = static_cast<int32_t>(nindptr - 1);
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) { for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
...@@ -398,7 +423,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, ...@@ -398,7 +423,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data, int LGBM_DatasetCreateFromMat(const void* data,
int data_type, int data_type,
int32_t nrow, int32_t nrow,
int32_t ncol, int32_t ncol,
...@@ -442,7 +467,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data, ...@@ -442,7 +467,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
reinterpret_cast<const Dataset*>(reference)); reinterpret_cast<const Dataset*>(reference));
} }
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) { for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
...@@ -456,7 +481,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data, ...@@ -456,7 +481,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr, int LGBM_DatasetCreateFromCSR(const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
...@@ -509,7 +534,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -509,7 +534,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
reinterpret_cast<const Dataset*>(reference)); reinterpret_cast<const Dataset*>(reference));
} }
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < nindptr - 1; ++i) { for (int i = 0; i < nindptr - 1; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
...@@ -523,7 +548,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -523,7 +548,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int col_ptr_type, int col_ptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
...@@ -575,7 +600,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -575,7 +600,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
reinterpret_cast<const Dataset*>(reference)); reinterpret_cast<const Dataset*>(reference));
} }
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < ncol_ptr - 1; ++i) { for (int i = 0; i < ncol_ptr - 1; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
...@@ -600,7 +625,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -600,7 +625,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset( int LGBM_DatasetGetSubset(
const DatasetHandle handle, const DatasetHandle handle,
const int32_t* used_row_indices, const int32_t* used_row_indices,
int32_t num_used_row_indices, int32_t num_used_row_indices,
...@@ -618,7 +643,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset( ...@@ -618,7 +643,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetSubset(
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames( int LGBM_DatasetSetFeatureNames(
DatasetHandle handle, DatasetHandle handle,
const char** feature_names, const char** feature_names,
int num_feature_names) { int num_feature_names) {
...@@ -632,7 +657,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames( ...@@ -632,7 +657,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetFeatureNames(
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames( int LGBM_DatasetGetFeatureNames(
DatasetHandle handle, DatasetHandle handle,
char** feature_names, char** feature_names,
int* num_feature_names) { int* num_feature_names) {
...@@ -646,13 +671,13 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames( ...@@ -646,13 +671,13 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetFeatureNames(
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetFree(DatasetHandle handle) { int LGBM_DatasetFree(DatasetHandle handle) {
API_BEGIN(); API_BEGIN();
delete reinterpret_cast<Dataset*>(handle); delete reinterpret_cast<Dataset*>(handle);
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle, int LGBM_DatasetSaveBinary(DatasetHandle handle,
const char* filename) { const char* filename) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
...@@ -660,7 +685,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle, ...@@ -660,7 +685,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSaveBinary(DatasetHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle, int LGBM_DatasetSetField(DatasetHandle handle,
const char* field_name, const char* field_name,
const void* field_data, const void* field_data,
int num_element, int num_element,
...@@ -679,7 +704,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle, ...@@ -679,7 +704,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle, int LGBM_DatasetGetField(DatasetHandle handle,
const char* field_name, const char* field_name,
int* out_len, int* out_len,
const void** out_ptr, const void** out_ptr,
...@@ -702,7 +727,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle, ...@@ -702,7 +727,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetField(DatasetHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle, int LGBM_DatasetGetNumData(DatasetHandle handle,
int* out) { int* out) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
...@@ -710,7 +735,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle, ...@@ -710,7 +735,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumData(DatasetHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle, int LGBM_DatasetGetNumFeature(DatasetHandle handle,
int* out) { int* out) {
API_BEGIN(); API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
...@@ -720,7 +745,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle, ...@@ -720,7 +745,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetGetNumFeature(DatasetHandle handle,
// ---- start of booster // ---- start of booster
LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data, int LGBM_BoosterCreate(const DatasetHandle train_data,
const char* parameters, const char* parameters,
BoosterHandle* out) { BoosterHandle* out) {
API_BEGIN(); API_BEGIN();
...@@ -730,7 +755,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data, ...@@ -730,7 +755,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCreate(const DatasetHandle train_data,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile( int LGBM_BoosterCreateFromModelfile(
const char* filename, const char* filename,
int* out_num_iterations, int* out_num_iterations,
BoosterHandle* out) { BoosterHandle* out) {
...@@ -741,7 +766,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile( ...@@ -741,7 +766,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCreateFromModelfile(
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString( int LGBM_BoosterLoadModelFromString(
const char* model_str, const char* model_str,
int* out_num_iterations, int* out_num_iterations,
BoosterHandle* out) { BoosterHandle* out) {
...@@ -753,13 +778,13 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString( ...@@ -753,13 +778,13 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle) { int LGBM_BoosterFree(BoosterHandle handle) {
API_BEGIN(); API_BEGIN();
delete reinterpret_cast<Booster*>(handle); delete reinterpret_cast<Booster*>(handle);
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle, int LGBM_BoosterMerge(BoosterHandle handle,
BoosterHandle other_handle) { BoosterHandle other_handle) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
...@@ -768,7 +793,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle, ...@@ -768,7 +793,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterMerge(BoosterHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle, int LGBM_BoosterAddValidData(BoosterHandle handle,
const DatasetHandle valid_data) { const DatasetHandle valid_data) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
...@@ -777,7 +802,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle, ...@@ -777,7 +802,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterAddValidData(BoosterHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle, int LGBM_BoosterResetTrainingData(BoosterHandle handle,
const DatasetHandle train_data) { const DatasetHandle train_data) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
...@@ -786,21 +811,21 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle, ...@@ -786,21 +811,21 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterResetTrainingData(BoosterHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) { int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->ResetConfig(parameters); ref_booster->ResetConfig(parameters);
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) { int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetBoosting()->NumberOfClasses(); *out_len = ref_booster->GetBoosting()->NumberOfClasses();
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) { int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
if (ref_booster->TrainOneIter()) { if (ref_booster->TrainOneIter()) {
...@@ -811,7 +836,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_fi ...@@ -811,7 +836,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_fi
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
const float* grad, const float* grad,
const float* hess, const float* hess,
int* is_finished) { int* is_finished) {
...@@ -825,49 +850,49 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, ...@@ -825,49 +850,49 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle) { int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->RollbackOneIter(); ref_booster->RollbackOneIter();
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) { int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_iteration = ref_booster->GetBoosting()->GetCurrentIteration(); *out_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) { int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetEvalCounts(); *out_len = ref_booster->GetEvalCounts();
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) { int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetEvalNames(out_strs); *out_len = ref_booster->GetEvalNames(out_strs);
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) { int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetFeatureNames(out_strs); *out_len = ref_booster->GetFeatureNames(out_strs);
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) { int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = ref_booster->GetBoosting()->MaxFeatureIdx() + 1; *out_len = ref_booster->GetBoosting()->MaxFeatureIdx() + 1;
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle, int LGBM_BoosterGetEval(BoosterHandle handle,
int data_idx, int data_idx,
int* out_len, int* out_len,
double* out_results) { double* out_results) {
...@@ -882,7 +907,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle, ...@@ -882,7 +907,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetEval(BoosterHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle, int LGBM_BoosterGetNumPredict(BoosterHandle handle,
int data_idx, int data_idx,
int64_t* out_len) { int64_t* out_len) {
API_BEGIN(); API_BEGIN();
...@@ -891,7 +916,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle, ...@@ -891,7 +916,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumPredict(BoosterHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle, int LGBM_BoosterGetPredict(BoosterHandle handle,
int data_idx, int data_idx,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
...@@ -901,7 +926,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle, ...@@ -901,7 +926,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetPredict(BoosterHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle, int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* data_filename, const char* data_filename,
int data_has_header, int data_has_header,
int predict_type, int predict_type,
...@@ -909,37 +934,23 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle, ...@@ -909,37 +934,23 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
const char* result_filename) { const char* result_filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type); ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header, result_filename);
bool bool_data_has_header = data_has_header > 0 ? true : false;
predictor.Predict(data_filename, result_filename, bool_data_has_header);
API_END(); API_END();
} }
int64_t GetNumPredOneRow(const Booster* ref_booster, int predict_type, int64_t num_iteration) { int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
int64_t num_preb_in_one_row = ref_booster->GetBoosting()->NumberOfClasses();
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
int64_t max_iteration = ref_booster->GetBoosting()->GetCurrentIteration();
if (num_iteration > 0) {
num_preb_in_one_row *= static_cast<int>(std::min(max_iteration, num_iteration));
} else {
num_preb_in_one_row *= max_iteration;
}
}
return num_preb_in_one_row;
}
LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
int num_row, int num_row,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
int64_t* out_len) { int64_t* out_len) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = static_cast<int64_t>(num_row * GetNumPredOneRow(ref_booster, predict_type, num_iteration)); *out_len = static_cast<int64_t>(num_row * ref_booster->GetBoosting()->NumPredictOneRow(
num_iteration, predict_type == C_API_PREDICT_LEAF_INDEX));
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, int LGBM_BoosterPredictForCSR(BoosterHandle handle,
const void* indptr, const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
...@@ -954,27 +965,13 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -954,27 +965,13 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
int nrow = static_cast<int>(nindptr - 1); int nrow = static_cast<int>(nindptr - 1);
OMP_INIT_EX(); ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, out_result, out_len);
#pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
auto one_row = get_row_fun(i);
auto predicton_result = predictor.GetPredictFunction()(one_row);
for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
*out_len = nrow * num_preb_in_one_row;
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, int LGBM_BoosterPredictForCSC(BoosterHandle handle,
const void* col_ptr, const void* col_ptr,
int col_ptr_type, int col_ptr_type,
const int32_t* indices, const int32_t* indices,
...@@ -989,37 +986,27 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -989,37 +986,27 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration);
int ncol = static_cast<int>(ncol_ptr - 1); int ncol = static_cast<int>(ncol_ptr - 1);
Threading::For<int64_t>(0, static_cast<int64_t>(num_row),
[&predictor, &out_result, num_preb_in_one_row, ncol, col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem]
(int, int64_t start, int64_t end) {
std::vector<CSC_RowIterator> iterators; std::vector<CSC_RowIterator> iterators;
for (int j = 0; j < ncol; ++j) { for (int j = 0; j < ncol; ++j) {
iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j); iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
} }
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
[&iterators, ncol](int i) {
std::vector<std::pair<int, double>> one_row; std::vector<std::pair<int, double>> one_row;
for (int64_t i = start; i < end; ++i) {
one_row.clear();
for (int j = 0; j < ncol; ++j) { for (int j = 0; j < ncol; ++j) {
auto val = iterators[j].Get(static_cast<int>(i)); auto val = iterators[j].Get(i);
if (std::fabs(val) > kEpsilon) { if (std::fabs(val) > kEpsilon) {
one_row.emplace_back(j, val); one_row.emplace_back(j, val);
} }
} }
auto predicton_result = predictor.GetPredictFunction()(one_row); return one_row;
for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) { };
out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]); ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, out_result, out_len);
}
}
});
*out_len = num_row * num_preb_in_one_row;
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, int LGBM_BoosterPredictForMat(BoosterHandle handle,
const void* data, const void* data,
int data_type, int data_type,
int32_t nrow, int32_t nrow,
...@@ -1031,26 +1018,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -1031,26 +1018,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto predictor = ref_booster->NewPredictor(static_cast<int>(num_iteration), predict_type);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
int64_t num_preb_in_one_row = GetNumPredOneRow(ref_booster, predict_type, num_iteration); ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, out_result, out_len);
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN();
auto one_row = get_row_fun(i);
auto predicton_result = predictor.GetPredictFunction()(one_row);
for (int j = 0; j < static_cast<int>(predicton_result.size()); ++j) {
out_result[i * num_preb_in_one_row + j] = static_cast<double>(predicton_result[j]);
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
*out_len = nrow * num_preb_in_one_row;
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, int LGBM_BoosterSaveModel(BoosterHandle handle,
int num_iteration, int num_iteration,
const char* filename) { const char* filename) {
API_BEGIN(); API_BEGIN();
...@@ -1059,7 +1032,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -1059,7 +1032,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle, int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration, int num_iteration,
int buffer_len, int buffer_len,
int* out_len, int* out_len,
...@@ -1074,7 +1047,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle, ...@@ -1074,7 +1047,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle, int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration, int num_iteration,
int buffer_len, int buffer_len,
int* out_len, int* out_len,
...@@ -1089,7 +1062,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle, ...@@ -1089,7 +1062,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle, int LGBM_BoosterGetLeafValue(BoosterHandle handle,
int tree_idx, int tree_idx,
int leaf_idx, int leaf_idx,
double* out_val) { double* out_val) {
...@@ -1099,7 +1072,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle, ...@@ -1099,7 +1072,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetLeafValue(BoosterHandle handle,
API_END(); API_END();
} }
LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle, int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int tree_idx, int tree_idx,
int leaf_idx, int leaf_idx,
double val) { double val) {
...@@ -1116,7 +1089,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1116,7 +1089,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
if (data_type == C_API_DTYPE_FLOAT32) { if (data_type == C_API_DTYPE_FLOAT32) {
const float* data_ptr = reinterpret_cast<const float*>(data); const float* data_ptr = reinterpret_cast<const float*>(data);
if (is_row_major) { if (is_row_major) {
return [data_ptr, num_col, num_row] (int row_idx) { return [data_ptr, num_col, num_row](int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx; auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
...@@ -1125,7 +1098,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1125,7 +1098,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret; return ret;
}; };
} else { } else {
return [data_ptr, num_col, num_row] (int row_idx) { return [data_ptr, num_col, num_row](int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx)); ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
...@@ -1136,7 +1109,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1136,7 +1109,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
} else if (data_type == C_API_DTYPE_FLOAT64) { } else if (data_type == C_API_DTYPE_FLOAT64) {
const double* data_ptr = reinterpret_cast<const double*>(data); const double* data_ptr = reinterpret_cast<const double*>(data);
if (is_row_major) { if (is_row_major) {
return [data_ptr, num_col, num_row] (int row_idx) { return [data_ptr, num_col, num_row](int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx; auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
...@@ -1145,7 +1118,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1145,7 +1118,7 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
return ret; return ret;
}; };
} else { } else {
return [data_ptr, num_col, num_row] (int row_idx) { return [data_ptr, num_col, num_row](int row_idx) {
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx)); ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
...@@ -1161,7 +1134,7 @@ std::function<std::vector<std::pair<int, double>>(int row_idx)> ...@@ -1161,7 +1134,7 @@ std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) { RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major); auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
if (inner_function != nullptr) { if (inner_function != nullptr) {
return [inner_function] (int row_idx) { return [inner_function](int row_idx) {
auto raw_values = inner_function(row_idx); auto raw_values = inner_function(row_idx);
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) { for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
...@@ -1181,7 +1154,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1181,7 +1154,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const float* data_ptr = reinterpret_cast<const float*>(data); const float* data_ptr = reinterpret_cast<const float*>(data);
if (indptr_type == C_API_DTYPE_INT32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1192,7 +1165,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1192,7 +1165,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
}; };
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1206,7 +1179,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1206,7 +1179,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
const double* data_ptr = reinterpret_cast<const double*>(data); const double* data_ptr = reinterpret_cast<const double*>(data);
if (indptr_type == C_API_DTYPE_INT32) { if (indptr_type == C_API_DTYPE_INT32) {
const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr); const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1217,7 +1190,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1217,7 +1190,7 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
}; };
} else if (indptr_type == C_API_DTYPE_INT64) { } else if (indptr_type == C_API_DTYPE_INT64) {
const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr); const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) { return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
std::vector<std::pair<int, double>> ret; std::vector<std::pair<int, double>> ret;
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
...@@ -1240,7 +1213,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1240,7 +1213,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1253,7 +1226,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1253,7 +1226,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1269,7 +1242,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1269,7 +1242,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr); const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
...@@ -1282,7 +1255,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1282,7 +1255,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr); const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
int64_t start = ptr_col_ptr[col_idx]; int64_t start = ptr_col_ptr[col_idx];
int64_t end = ptr_col_ptr[col_idx + 1]; int64_t end = ptr_col_ptr[col_idx + 1];
return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) { return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
int64_t i = static_cast<int64_t>(start + bias); int64_t i = static_cast<int64_t>(start + bias);
if (i >= end) { if (i >= end) {
return std::make_pair(-1, 0.0); return std::make_pair(-1, 0.0);
......
...@@ -888,7 +888,8 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat ...@@ -888,7 +888,8 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
// parser // parser
parser->ParseOneLine(text_data[i].c_str(), &oneline_features, &tmp_label); parser->ParseOneLine(text_data[i].c_str(), &oneline_features, &tmp_label);
// set initial score // set initial score
std::vector<double> oneline_init_score = predict_fun_(oneline_features); std::vector<double> oneline_init_score(num_class_);
predict_fun_(oneline_features, oneline_init_score.data());
for (int k = 0; k < num_class_; ++k) { for (int k = 0; k < num_class_; ++k) {
init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]); init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
} }
...@@ -947,7 +948,8 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* ...@@ -947,7 +948,8 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label); parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
// set initial score // set initial score
if (!init_score.empty()) { if (!init_score.empty()) {
std::vector<double> oneline_init_score = predict_fun_(oneline_features); std::vector<double> oneline_init_score(num_class_);
predict_fun_(oneline_features, oneline_init_score.data());
for (int k = 0; k < num_class_; ++k) { for (int k = 0; k < num_class_; ++k) {
init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]); init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
} }
......
...@@ -318,6 +318,7 @@ std::string Tree::ToString() { ...@@ -318,6 +318,7 @@ std::string Tree::ToString() {
str_buf << "internal_count=" str_buf << "internal_count="
<< Common::ArrayToString<data_size_t>(internal_count_, num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<data_size_t>(internal_count_, num_leaves_ - 1, ' ') << std::endl;
str_buf << "shrinkage=" << shrinkage_ << std::endl; str_buf << "shrinkage=" << shrinkage_ << std::endl;
str_buf << "has_categorical=" << (has_categorical_ ? 1 : 0) << std::endl;
str_buf << std::endl; str_buf << std::endl;
return str_buf.str(); return str_buf.str();
} }
...@@ -327,6 +328,7 @@ std::string Tree::ToJSON() { ...@@ -327,6 +328,7 @@ std::string Tree::ToJSON() {
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2); str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
str_buf << "\"num_leaves\":" << num_leaves_ << "," << std::endl; str_buf << "\"num_leaves\":" << num_leaves_ << "," << std::endl;
str_buf << "\"shrinkage\":" << shrinkage_ << "," << std::endl; str_buf << "\"shrinkage\":" << shrinkage_ << "," << std::endl;
str_buf << "\"has_categorical\":" << (has_categorical_ ? 1 : 0) << "," << std::endl;
str_buf << "\"tree_structure\":" << NodeToJSON(0) << std::endl; str_buf << "\"tree_structure\":" << NodeToJSON(0) << std::endl;
return str_buf.str(); return str_buf.str();
...@@ -454,6 +456,15 @@ Tree::Tree(const std::string& str) { ...@@ -454,6 +456,15 @@ Tree::Tree(const std::string& str) {
} else { } else {
shrinkage_ = 1.0f; shrinkage_ = 1.0f;
} }
if (key_vals.count("has_categorical")) {
int t = 0;
Common::Atoi(key_vals["has_categorical"].c_str(), &t);
has_categorical_ = t > 0;
} else {
has_categorical_ = false;
}
} }
} // namespace LightGBM } // namespace LightGBM
...@@ -54,8 +54,7 @@ public: ...@@ -54,8 +54,7 @@ public:
return -1.0f; return -1.0f;
} }
std::vector<double> Eval(const double* score, const ObjectiveFunction* objective, std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override {
int) const override {
double sum_loss = 0.0f; double sum_loss = 0.0f;
if (objective == nullptr) { if (objective == nullptr) {
if (weights_ == nullptr) { if (weights_ == nullptr) {
...@@ -75,15 +74,16 @@ public: ...@@ -75,15 +74,16 @@ public:
if (weights_ == nullptr) { if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
double prob = objective->ConvertOutput(score[i]); double prob = 0;
objective->ConvertOutput(&score[i], &prob);
// add loss // add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob); sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob);
} }
} else { } else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
// sigmoid transform double prob = 0;
double prob = objective->ConvertOutput(score[i]); objective->ConvertOutput(&score[i], &prob);
// add loss // add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob) * weights_[i]; sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], prob) * weights_[i];
} }
...@@ -189,8 +189,7 @@ public: ...@@ -189,8 +189,7 @@ public:
} }
} }
std::vector<double> Eval(const double* score, const ObjectiveFunction*, std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
int) const override {
// get indices sorted by score, descent order // get indices sorted by score, descent order
std::vector<data_size_t> sorted_idx; std::vector<data_size_t> sorted_idx;
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
......
...@@ -93,8 +93,7 @@ public: ...@@ -93,8 +93,7 @@ public:
cur_left = cur_k; cur_left = cur_k;
} }
} }
std::vector<double> Eval(const double* score, const ObjectiveFunction*, std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
int) const override {
// some buffers for multi-threading sum up // some buffers for multi-threading sum up
std::vector<std::vector<double>> result_buffer_; std::vector<std::vector<double>> result_buffer_;
for (int i = 0; i < num_threads_; ++i) { for (int i = 0; i < num_threads_; ++i) {
......
...@@ -15,8 +15,8 @@ namespace LightGBM { ...@@ -15,8 +15,8 @@ namespace LightGBM {
template<typename PointWiseLossCalculator> template<typename PointWiseLossCalculator>
class MulticlassMetric: public Metric { class MulticlassMetric: public Metric {
public: public:
explicit MulticlassMetric(const MetricConfig&) { explicit MulticlassMetric(const MetricConfig& config) {
num_class_ = config.num_class;
} }
virtual ~MulticlassMetric() { virtual ~MulticlassMetric() {
...@@ -49,31 +49,38 @@ public: ...@@ -49,31 +49,38 @@ public:
return -1.0f; return -1.0f;
} }
std::vector<double> Eval(const double* score, const ObjectiveFunction* objective, std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override {
int num_tree_per_iteration) const override {
double sum_loss = 0.0; double sum_loss = 0.0;
int num_tree_per_iteration = num_class_;
int num_pred_per_row = num_class_;
if (objective != nullptr) {
num_tree_per_iteration = objective->NumTreePerIteration();
num_pred_per_row = objective->NumPredictOneRow();
}
if (objective != nullptr) { if (objective != nullptr) {
if (weights_ == nullptr) { if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
std::vector<double> rec(num_tree_per_iteration); std::vector<double> raw_score(num_tree_per_iteration);
for (int k = 0; k < num_tree_per_iteration; ++k) { for (int k = 0; k < num_tree_per_iteration; ++k) {
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
rec[k] = static_cast<double>(score[idx]); raw_score[k] = static_cast<double>(score[idx]);
} }
rec = objective->ConvertOutput(rec); std::vector<double> rec(num_pred_per_row);
objective->ConvertOutput(raw_score.data(), rec.data());
// add loss // add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec); sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
} }
} else { } else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
std::vector<double> rec(num_tree_per_iteration); std::vector<double> raw_score(num_tree_per_iteration);
for (int k = 0; k < num_tree_per_iteration; ++k) { for (int k = 0; k < num_tree_per_iteration; ++k) {
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
rec[k] = static_cast<double>(score[idx]); raw_score[k] = static_cast<double>(score[idx]);
} }
rec = objective->ConvertOutput(rec); std::vector<double> rec(num_pred_per_row);
objective->ConvertOutput(raw_score.data(), rec.data());
// add loss // add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i]; sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
} }
...@@ -118,6 +125,7 @@ private: ...@@ -118,6 +125,7 @@ private:
double sum_weights_; double sum_weights_;
/*! \brief Name of this test set */ /*! \brief Name of this test set */
std::vector<std::string> name_; std::vector<std::string> name_;
int num_class_;
}; };
/*! \brief L2 loss for multiclass task */ /*! \brief L2 loss for multiclass task */
......
...@@ -82,7 +82,7 @@ public: ...@@ -82,7 +82,7 @@ public:
return 1.0f; return 1.0f;
} }
std::vector<double> Eval(const double* score, const ObjectiveFunction*, int) const override { std::vector<double> Eval(const double* score, const ObjectiveFunction*) const override {
// some buffers for multi-threading sum up // some buffers for multi-threading sum up
std::vector<std::vector<double>> result_buffer_; std::vector<std::vector<double>> result_buffer_;
for (int i = 0; i < num_threads_; ++i) { for (int i = 0; i < num_threads_; ++i) {
......
...@@ -48,7 +48,7 @@ public: ...@@ -48,7 +48,7 @@ public:
} }
} }
std::vector<double> Eval(const double* score, const ObjectiveFunction* objective, int) const override { std::vector<double> Eval(const double* score, const ObjectiveFunction* objective) const override {
double sum_loss = 0.0f; double sum_loss = 0.0f;
if (objective == nullptr) { if (objective == nullptr) {
if (weights_ == nullptr) { if (weights_ == nullptr) {
...@@ -69,13 +69,17 @@ public: ...@@ -69,13 +69,17 @@ public:
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
// add loss // add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], objective->ConvertOutput(score[i]), huber_delta_, fair_c_); double t = 0;
objective->ConvertOutput(&score[i], &t);
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], t, huber_delta_, fair_c_);
} }
} else { } else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss) #pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
// add loss // add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], objective->ConvertOutput(score[i]), huber_delta_, fair_c_) * weights_[i]; double t = 0;
objective->ConvertOutput(&score[i], &t);
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], t, huber_delta_, fair_c_) * weights_[i];
} }
} }
} }
......
...@@ -116,13 +116,8 @@ public: ...@@ -116,13 +116,8 @@ public:
return "binary"; return "binary";
} }
std::vector<double> ConvertOutput(std::vector<double>& input) const override { void ConvertOutput(const double* input, double* output) const override {
input[0] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[0])); output[0] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[0]));
return input;
}
double ConvertOutput(double input) const override {
return 1.0f / (1.0f + std::exp(-sigmoid_ * input));
} }
std::string ToString() const override { std::string ToString() const override {
......
...@@ -113,9 +113,8 @@ public: ...@@ -113,9 +113,8 @@ public:
} }
} }
std::vector<double> ConvertOutput(std::vector<double>& input) const override { void ConvertOutput(const double* input, double* output) const override {
Common::Softmax(input.data(), num_class_); Common::Softmax(input, output, num_class_);
return input;
} }
const char* GetName() const override { const char* GetName() const override {
...@@ -131,7 +130,9 @@ public: ...@@ -131,7 +130,9 @@ public:
bool SkipEmptyClass() const override { return true; } bool SkipEmptyClass() const override { return true; }
int numTreePerIteration() const override { return num_class_; } int NumTreePerIteration() const override { return num_class_; }
int NumPredictOneRow() const override { return num_class_; }
private: private:
/*! \brief Number of data */ /*! \brief Number of data */
...@@ -206,11 +207,10 @@ public: ...@@ -206,11 +207,10 @@ public:
return "multiclassova"; return "multiclassova";
} }
std::vector<double> ConvertOutput(std::vector<double>& input) const override { void ConvertOutput(const double* input, double* output) const override {
for (int i = 0; i < num_class_; ++i) { for (int i = 0; i < num_class_; ++i) {
input[i] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[i])); output[i] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[i]));
} }
return input;
} }
std::string ToString() const override { std::string ToString() const override {
...@@ -223,7 +223,9 @@ public: ...@@ -223,7 +223,9 @@ public:
bool SkipEmptyClass() const override { return true; } bool SkipEmptyClass() const override { return true; }
int numTreePerIteration() const override { return num_class_; } int NumTreePerIteration() const override { return num_class_; }
int NumPredictOneRow() const override { return num_class_; }
private: private:
/*! \brief Number of data */ /*! \brief Number of data */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment