Commit 233f1b71 authored by Guolin Ke's avatar Guolin Ke
Browse files

add get/set c_api for leaf value

parent 1cb3aa4e
...@@ -499,6 +499,32 @@ DllExport int LGBM_BoosterDumpModel(BoosterHandle handle, ...@@ -499,6 +499,32 @@ DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
int64_t* out_len, int64_t* out_len,
char** out_str); char** out_str);
/*!
* \brief Get leaf value
* \param handle handle
* \param tree_idx index of tree
* \param leaf_idx index of leaf
* \param out_val out result
* \return 0 when succeed, -1 when failure happens
*/
DllExport int LGBM_BoosterGetLeafValue(BoosterHandle handle,
int tree_idx,
int leaf_idx,
float* out_val);
/*!
* \brief Set leaf value
* \param handle handle
* \param tree_idx index of tree
* \param leaf_idx index of leaf
* \param val leaf value
* \return 0 when succeed, -1 when failure happens
*/
DllExport int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int tree_idx,
int leaf_idx,
float val);
// some help functions used to convert data // some help functions used to convert data
std::function<std::vector<double>(int row_idx)> std::function<std::vector<double>(int row_idx)>
......
...@@ -50,9 +50,14 @@ public: ...@@ -50,9 +50,14 @@ public:
double threshold_double, double left_value, double threshold_double, double left_value,
double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain); double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain);
/*! \brief Get the output of one leave */ /*! \brief Get the output of one leaf */
inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; } inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; }
/*! \brief Set the output of one leaf */
inline void SetLeafOutput(int leaf, double output) {
leaf_value_[leaf] = output;
}
/*! /*!
* \brief Adding prediction value of this tree model to scores * \brief Adding prediction value of this tree model to scores
* \param data The dataset * \param data The dataset
......
...@@ -193,6 +193,18 @@ public: ...@@ -193,6 +193,18 @@ public:
static_cast<int>(models_.size()) / num_class_); static_cast<int>(models_.size()) / num_class_);
} }
inline double GetLeafValue(int tree_idx, int leaf_idx) const {
CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size());
CHECK(leaf_idx >= 0 && leaf_idx < models_[tree_idx]->num_leaves());
return models_[tree_idx]->LeafOutput(leaf_idx);
}
inline void SetLeafValue(int tree_idx, int leaf_idx, double val) {
CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size());
CHECK(leaf_idx >= 0 && leaf_idx < models_[tree_idx]->num_leaves());
models_[tree_idx]->SetLeafOutput(leaf_idx, val);
}
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
*/ */
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <mutex> #include <mutex>
#include "./application/predictor.hpp" #include "./application/predictor.hpp"
#include "./boosting/gbdt.h"
namespace LightGBM { namespace LightGBM {
...@@ -190,6 +191,15 @@ public: ...@@ -190,6 +191,15 @@ public:
return boosting_->DumpModel(); return boosting_->DumpModel();
} }
double GetLeafValue(int tree_idx, int leaf_idx) const {
return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
}
void SetLeafValue(int tree_idx, int leaf_idx, double val) {
std::lock_guard<std::mutex> lock(mutex_);
dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
}
int GetEvalCounts() const { int GetEvalCounts() const {
int ret = 0; int ret = 0;
for (const auto& metric : train_metric_) { for (const auto& metric : train_metric_) {
...@@ -788,6 +798,29 @@ DllExport int LGBM_BoosterDumpModel(BoosterHandle handle, ...@@ -788,6 +798,29 @@ DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
API_END(); API_END();
} }
DllExport int LGBM_BoosterGetLeafValue(BoosterHandle handle,
int tree_idx,
int leaf_idx,
float* out_val) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_val = static_cast<float>(ref_booster->GetLeafValue(tree_idx, leaf_idx));
API_END();
}
DllExport int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int tree_idx,
int leaf_idx,
float val) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SetLeafValue(tree_idx, leaf_idx, static_cast<double>(val));
API_END();
}
// ---- start of some help functions // ---- start of some help functions
std::function<std::vector<double>(int row_idx)> std::function<std::vector<double>(int row_idx)>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment