Commit 4accb9d4 authored by Guolin Ke's avatar Guolin Ke
Browse files

support merge two booster

parent 14a67b7e
......@@ -35,6 +35,11 @@ public:
const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) = 0;
/*!
* \brief Merge model from other boosting object
* \param other
*/
virtual void MergeFrom(const Boosting* other) = 0;
/*!
* \brief Reset Config for current boosting
* \param config Configs for boosting
......@@ -179,6 +184,7 @@ public:
Boosting(const Boosting&) = delete;
static void LoadFileToBoosting(Boosting* boosting, const char* filename);
/*!
* \brief Create boosting object
* \param type Type of boosting
......
......@@ -240,8 +240,18 @@ DllExport int LGBM_BoosterCreateFromModelfile(
*/
DllExport int LGBM_BoosterFree(BoosterHandle handle);
/*!
* \brief Merge model in two booster to first handle
* \param handle handle, will merge other handle to this
* \param other_handle
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
BoosterHandle other_handle);
/*!
* \brief Add new validation to booster
* \param handle handle
* \param valid_data validation data set
* \return 0 when success, -1 when failure happens
*/
......@@ -249,7 +259,8 @@ DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
const DatesetHandle valid_data);
/*!
* \brief Add new validation to booster
* \brief Reset training data for booster
* \param handle handle
* \param train_data training data set
* \return 0 when success, -1 when failure happens
*/
......@@ -258,6 +269,7 @@ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
/*!
* \brief Reset config for current booster
* \param handle handle
* \param parameters format: 'key1=value1 key2=value2'
* \return 0 when success, -1 when failure happens
*/
......@@ -265,6 +277,7 @@ DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* param
/*!
* \brief Get number of class
* \param handle handle
* \return number of class
*/
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len);
......
......@@ -101,10 +101,6 @@ public:
/*! \brief Serialize this object by string*/
std::string ToString();
/*! \brief Disable copy */
Tree& operator=(const Tree&) = delete;
/*! \brief Disable copy */
Tree(const Tree&) = delete;
private:
/*!
* \brief Find leaf index of which record belongs by data
......
......@@ -36,6 +36,14 @@ public:
const std::vector<const Metric*>& training_metrics)
override;
void MergeFrom(const Boosting* other) override {
auto other_gbdt = reinterpret_cast<const GBDT*>(other);
for (const auto& tree : other_gbdt->models_) {
auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get())));
models_.push_back(std::move(new_tree));
}
}
/*!
* \brief Reset Config for current boosting
* \param config Configs for boosting
......
......@@ -42,6 +42,10 @@ public:
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
void MergeFrom(const Booster* other) {
boosting_->MergeFrom(other->boosting_.get());
}
~Booster() {
}
......@@ -465,6 +469,14 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_END();
}
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
BoosterHandle other_handle) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Booster* ref_other_booster = reinterpret_cast<Booster*>(other_handle);
ref_booster->MergeFrom(ref_other_booster);
API_END();
}
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
const DatesetHandle valid_data) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment